{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# !pip install -r requirements.txt"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "4349c9c03c52b3f4"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d814763380c19c4",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import xarray as xr\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "# path config\n",
    "feature_path = ''\n",
    "gt_path = ''\n",
    "years = ['2019']\n",
    "fcst_steps = list(range(1, 73, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class Feature:\n",
    "    def __init__(self):\n",
    "        self.path = feature_path\n",
    "        self.years = years\n",
    "        self.fcst_steps = fcst_steps\n",
    "        self.features_paths_dict = self.get_features_paths()\n",
    "\n",
    "    def get_features_paths(self):\n",
    "        init_time_path_dict = {}\n",
    "        for year in self.years:\n",
    "            init_time_dir_year = os.listdir(os.path.join(self.path, year))\n",
    "            for init_time in sorted(init_time_dir_year):\n",
    "                init_time_path_dict[pd.to_datetime(init_time)] = os.path.join(self.path, year, init_time)\n",
    "        return init_time_path_dict\n",
    "\n",
    "    def get_fts(self, init_time):\n",
    "        return xr.open_mfdataset(self.features_paths_dict.get(init_time) + '/*').sel(lead_time=self.fcst_steps).isel(\n",
    "            time=0)\n",
    "\n",
    "\n",
    "class GT:\n",
    "    def __init__(self):\n",
    "        self.path = gt_path\n",
    "        self.years = years\n",
    "        self.fcst_steps = fcst_steps\n",
    "        self.gt_paths = [os.path.join(self.path, f'{year}.nc') for year in self.years]\n",
    "        self.gts = xr.open_mfdataset(self.gt_paths)\n",
    "\n",
    "    def parser_gt_timestamps(self, init_time):\n",
    "        return [init_time + pd.Timedelta(f'{fcst_step}h') for fcst_step in self.fcst_steps]\n",
    "\n",
    "    def get_gts(self, init_time):\n",
    "        return self.gts.sel(time=self.parser_gt_timestamps(init_time))\n",
    "\n",
    "\n",
    "class mydataset(Dataset):\n",
    "    def __init__(self):\n",
    "        self.ft = Feature()\n",
    "        self.gt = GT()\n",
    "        self.features_paths_dict = self.ft.features_paths_dict\n",
    "        self.init_times = list(self.features_paths_dict.keys())\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        init_time = self.init_times[index]\n",
    "        ft_item = self.ft.get_fts(init_time).to_array().isel(variable=0).values\n",
    "        gt_item = self.gt.get_gts(init_time).to_array().isel(variable=0).values\n",
    "        return ft_item, gt_item\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(list(self.init_times))\n",
    "    \n",
    "# define dataset\n",
    "my_data = mydataset()\n",
    "print('sample num:', mydataset().__len__())\n",
    "train_loader = DataLoader(my_data, batch_size=1, shuffle=True)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "initial_id"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "class Model(nn.Module):\n",
    "    def __init__(self, num_in_ch, num_out_ch):\n",
    "        super(Model, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        B, S, C, W, H = tuple(x.shape)\n",
    "        x = x.reshape(B, -1, W, H)\n",
    "        out = self.conv1(x)\n",
    "        out = out.reshape(B, S, W, H)\n",
    "        return out\n",
    "\n",
    "# define model\n",
    "in_varibales = 24\n",
    "in_times = len(fcst_steps)\n",
    "out_varibales = 1\n",
    "out_times = len(fcst_steps)\n",
    "input_size = in_times * in_varibales\n",
    "output_size = out_times * out_varibales\n",
    "model = Model(input_size, output_size).cuda()"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "b7b879bc9aa3b221"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# define loss\n",
    "loss_func = nn.MSELoss()"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "12303299f9a3b7e6"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# train\n",
    "import numpy as np\n",
    "if __name__ == '__main__':\n",
    "    for index, (ft_item, gt_item) in enumerate(train_loader):\n",
    "        ft_item = ft_item.cuda().float()\n",
    "        gt_item = gt_item.cuda().float()\n",
    "        output_item = model(ft_item)\n",
    "        loss = loss_func(output_item, gt_item)\n",
    "        print(10*'*', np.round(loss.item(),4), 10*'*')"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "245090cbefa01701"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
